# load libraries

import pinecone
import requests, os, zipfile
from torchvision import datasets
import random
from torchvision import transforms as ts
import torchvision.models as models

import matplotlib.pyplot as plt
from PIL import Image

from tqdm.autonotebook import tqdm
import pandas as pd

import pinecone.graph
import pinecone.service
import pinecone.connector

# Read the Pinecone API key from the environment rather than committing a
# secret to the source file. Export PINECONE_API_KEY before running.
# NOTE(review): a literal key used to be committed on this line; it should be
# treated as leaked and rotated.
api_key = os.environ.get("PINECONE_API_KEY")
if not api_key:
    raise RuntimeError(
        "Set the PINECONE_API_KEY environment variable before running."
    )
pinecone.init(api_key=api_key)
#define functions
class ImageEmbedder:
    """Turn an image file into a flat vector using a pretrained SqueezeNet.

    The vector is the flattened output of ``torchvision``'s
    ``squeezenet1_0`` for a single resized, center-cropped, normalized image.
    """

    def __init__(self):
        # Per-channel mean/std applied to the image tensor before the model.
        self.normalize = ts.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        # see https://pytorch.org/vision/0.8/models.html for many more model options
        self.model = models.squeezenet1_0(pretrained=True)  # squeezenet
        # Put the network in inference mode. Modules are created in training
        # mode, where dropout stays active and the same image would produce a
        # different vector on every call.
        self.model.eval()

    def embed(self, image_file_name):
        """Return the embedding of the image stored at ``image_file_name``.

        Args:
            image_file_name: Path (or file object) accepted by ``Image.open``.

        Returns:
            A 1-D numpy array holding the flattened model output.
        """
        # convert() loads the pixel data into a new image, so the underlying
        # file can be closed as soon as the conversion is done.
        with Image.open(image_file_name) as raw:
            image = raw.convert("RGB")
        image = ts.Resize(256)(image)
        image = ts.CenterCrop(224)(image)
        tensor = ts.ToTensor()(image)
        # Add the batch dimension the model expects: (1, channels, H, W).
        tensor = self.normalize(tensor).reshape(1, 3, 224, 224)
        vector = self.model(tensor).cpu().detach().numpy().flatten()
        return vector


# Single shared embedder; constructing it loads the pretrained model once.
image_embedder = ImageEmbedder()


def show_images_horizontally(file_names):
    """Display the given image files side by side in a single row.

    Args:
        file_names: Sized collection of image paths. An empty collection is
            a no-op.
    """
    m = len(file_names)
    if m == 0:
        # A subplot grid with zero columns cannot be created; nothing to show.
        return
    # squeeze=False always yields a 2-D array of axes, so the single-image
    # case (where subplots would otherwise return a bare Axes) also works.
    fig, axes = plt.subplots(1, m, squeeze=False)
    fig.set_figwidth(1.5 * m)
    for a, f in zip(axes[0], file_names):
        with Image.open(f) as img:
            a.imshow(img)
        a.axis("off")
    plt.show()
    # pyplot keeps every figure alive until it is closed explicitly; release
    # it so repeated calls do not pile up open figures.
    plt.close(fig)


def show_image(file_name):
    """Display a single image file in a small figure without axes.

    Args:
        file_name: Path (or file object) accepted by ``Image.open``.
    """
    fig, ax = plt.subplots(1, 1)
    fig.set_figwidth(1.3)
    with Image.open(file_name) as img:
        ax.imshow(img)
    ax.axis("off")
    plt.show()
    # pyplot keeps every figure alive until it is closed explicitly; release
    # it so repeated calls do not pile up open figures.
    plt.close(fig)

# Manage directories (so much easier in R)

# Locate the project root via rprojroot's git-root criterion, then derive the
# image folders from it. The Python code below reads them as `r.train_dir`
# and `r.test_dir`.
root <- rprojroot::find_root(rprojroot::is_git_root)

train_dir <- file.path(root, "images", "train")
test_dir <- file.path(root, "images", "test")
# Empty frames; their columns are filled in one at a time below.
df_train = pd.DataFrame()
df_test = pd.DataFrame()

# Bare file names found in the two image folders defined on the R side.
train_files = os.listdir(r.train_dir)
test_files = os.listdir(r.test_dir)

# Full paths: the folder joined with each listed name.
train_file_names = [os.path.join(r.train_dir, name) for name in train_files]
test_file_names = [os.path.join(r.test_dir, name) for name in test_files]


df_train["image_file_name"] = train_file_names
# The id is whatever follows the folder prefix in the path, so that
# `r.train_dir + embedding_id` rebuilds the full path when results are shown.
df_train["embedding_id"] = [path.split(r.train_dir)[-1] for path in train_file_names]
# One embedding per training image, with a progress bar.
df_train["embedding"] = [image_embedder.embed(path) for path in tqdm(train_file_names)]
## 
  0%|          | 0/22 [00:00<?, ?it/s]
  9%|9         | 2/22 [00:00<00:01, 12.11it/s]
 18%|#8        | 4/22 [00:00<00:01, 12.49it/s]
 27%|##7       | 6/22 [00:00<00:01, 11.65it/s]
 36%|###6      | 8/22 [00:00<00:01, 12.65it/s]
 45%|####5     | 10/22 [00:00<00:00, 13.14it/s]
 55%|#####4    | 12/22 [00:00<00:00, 12.06it/s]
 64%|######3   | 14/22 [00:01<00:00, 11.12it/s]
 73%|#######2  | 16/22 [00:01<00:00, 11.09it/s]
 82%|########1 | 18/22 [00:01<00:00, 11.26it/s]
 91%|######### | 20/22 [00:01<00:00, 11.40it/s]
100%|##########| 22/22 [00:01<00:00, 11.84it/s]
100%|##########| 22/22 [00:01<00:00, 11.80it/s]
# Same three columns as the training frame, built from the test images.
df_test["image_file_name"] = test_file_names
# The id is whatever follows the test folder prefix in the path.
df_test["embedding_id"] = [path.split(r.test_dir)[-1] for path in test_file_names]
# One embedding per test image, with a progress bar.
df_test["embedding"] = [image_embedder.embed(path) for path in tqdm(test_file_names)]
## 
  0%|          | 0/15 [00:00<?, ?it/s]
  7%|6         | 1/15 [00:00<00:01,  7.96it/s]
 13%|#3        | 2/15 [00:00<00:01,  7.70it/s]
 20%|##        | 3/15 [00:00<00:02,  5.82it/s]
 27%|##6       | 4/15 [00:00<00:02,  4.95it/s]
 40%|####      | 6/15 [00:00<00:01,  7.26it/s]
 47%|####6     | 7/15 [00:01<00:01,  7.05it/s]
 53%|#####3    | 8/15 [00:01<00:00,  7.08it/s]
 60%|######    | 9/15 [00:01<00:00,  6.98it/s]
 67%|######6   | 10/15 [00:01<00:00,  5.32it/s]
 73%|#######3  | 11/15 [00:01<00:00,  4.72it/s]
 80%|########  | 12/15 [00:02<00:00,  4.94it/s]
 93%|#########3| 14/15 [00:02<00:00,  7.15it/s]
100%|##########| 15/15 [00:02<00:00,  6.51it/s]
# Choosing an arbitrary name for my service
service_name = "simple-pytorch-image-search"

# Checking whether the service is already deployed.
# Deployment only happens when the name is missing from the service listing,
# so re-running this section reuses an existing service.
if service_name not in pinecone.service.ls():
    # Index graph configured with the euclidean metric and a single shard.
    graph = pinecone.graph.IndexGraph(metric="euclidean", shards=1)
    pinecone.service.deploy(service_name, graph)

# Connect to the service and request its info (the recorded output below
# shows the index size).
conn = pinecone.connector.connect(service_name)
conn.info()
## InfoResult(index_size=22)
# Send (embedding_id, embedding) pairs for every training image to the index
# and collect the responses.
acks = conn.upsert(items=zip(df_train.embedding_id, df_train.embedding)).collect()
## 
0it [00:00, ?it/s]
1it [00:00,  1.30it/s]
22it [00:00, 28.57it/s]
# Request the index info again now that the upsert has run.
conn.info()
## InfoResult(index_size=22)
# Query the index with every test embedding; `res[i]` is then used as the
# result for test image `i`.
# NOTE(review): batch_size=15 presumably matches the 15 test images of the
# recorded run — confirm whether it should follow len(df_test).
res = conn.query(df_test.embedding, batch_size=15).collect()  # issuing queries
## 
0it [00:00, ?it/s]
1it [00:00,  1.02it/s]
15it [00:00, 15.28it/s]
# Show each query image followed by the training images returned for it.
# Iterate over every test image instead of a hard-coded count: the previous
# fixed range of 14 skipped the last of the 15 queries.
for i, query_file_name in enumerate(df_test.image_file_name):
    print(f"Query {i+1} and search results")
    show_image(query_file_name)
    # Result ids are path suffixes relative to the training folder, so
    # prefixing the folder rebuilds each matched image's full path.
    show_images_horizontally(
        [r.train_dir + embedding_id for embedding_id in res[i].ids]
    )

    print("-" * 80)
## Query 1 and search results
## --------------------------------------------------------------------------------
## Query 2 and search results
## --------------------------------------------------------------------------------
## Query 3 and search results
## --------------------------------------------------------------------------------
## Query 4 and search results
## --------------------------------------------------------------------------------
## Query 5 and search results
## --------------------------------------------------------------------------------
## Query 6 and search results
## --------------------------------------------------------------------------------
## Query 7 and search results
## --------------------------------------------------------------------------------
## Query 8 and search results
## --------------------------------------------------------------------------------
## Query 9 and search results
## --------------------------------------------------------------------------------
## Query 10 and search results
## --------------------------------------------------------------------------------
## Query 11 and search results
## --------------------------------------------------------------------------------
## Query 12 and search results
## --------------------------------------------------------------------------------
## Query 13 and search results
## --------------------------------------------------------------------------------
## Query 14 and search results
## --------------------------------------------------------------------------------
## 
## C:/Users/Aubur/AppData/Local/r-miniconda/envs/r-reticulate/python.exe:2: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).
## Traceback (most recent call last):
##   File "C:\Users\Aubur\AppData\Local\R-MINI~1\envs\R-RETI~1\lib\site-packages\matplotlib\backends\backend_qt5.py", line 480, in _draw_idle
##     self.draw()
##   File "C:\Users\Aubur\AppData\Local\R-MINI~1\envs\R-RETI~1\lib\site-packages\matplotlib\backends\backend_agg.py", line 407, in draw
##     self.figure.draw(self.renderer)
##   File "C:\Users\Aubur\AppData\Local\R-MINI~1\envs\R-RETI~1\lib\site-packages\matplotlib\artist.py", line 41, in draw_wrapper
##     return draw(artist, renderer, *args, **kwargs)
##   File "C:\Users\Aubur\AppData\Local\R-MINI~1\envs\R-RETI~1\lib\site-packages\matplotlib\figure.py", line 1864, in draw
##     renderer, self, artists, self.suppressComposite)
##   File "C:\Users\Aubur\AppData\Local\R-MINI~1\envs\R-RETI~1\lib\site-packages\matplotlib\image.py", line 131, in _draw_list_compositing_images
##     a.draw(renderer)
##   File "C:\Users\Aubur\AppData\Local\R-MINI~1\envs\R-RETI~1\lib\site-packages\matplotlib\artist.py", line 41, in draw_wrapper
##     return draw(artist, renderer, *args, **kwargs)
##   File "C:\Users\Aubur\AppData\Local\R-MINI~1\envs\R-RETI~1\lib\site-packages\matplotlib\cbook\deprecation.py", line 411, in wrapper
##     return func(*inner_args, **inner_kwargs)
##   File "C:\Users\Aubur\AppData\Local\R-MINI~1\envs\R-RETI~1\lib\site-packages\matplotlib\axes\_base.py", line 2707, in draw
##     self._update_title_position(renderer)
##   File "C:\Users\Aubur\AppData\Local\R-MINI~1\envs\R-RETI~1\lib\site-packages\matplotlib\axes\_base.py", line 2636, in _update_title_position
##     if (ax.xaxis.get_ticks_position() in ['top', 'unknown']
##   File "C:\Users\Aubur\AppData\Local\R-MINI~1\envs\R-RETI~1\lib\site-packages\matplotlib\axis.py", line 2207, in get_ticks_position
##     self._get_ticks_position()]
##   File "C:\Users\Aubur\AppData\Local\R-MINI~1\envs\R-RETI~1\lib\site-packages\matplotlib\axis.py", line 1893, in _get_ticks_position
##     minor = self.minorTicks[0]
## IndexError: list index out of range